{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S3-tJZy35Ck_"
   },
   "source": [
    "# Trainer\n",
    "Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the media.\n",
    "\n",
    "<center>\n",
    "<img src=../_static/images/structure.svg></img>\n",
    "</center>\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ifsEQMzZ6mmz"
   },
   "source": [
    "## Usages\n",
    "In Tianshou v0.5.1, there are three types of Trainer. They are designed to be used in  on-policy training, off-policy training and offline training respectively. We will use on-policy trainer as an example and leave the other two for further reading."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XfsuU2AAE52C"
   },
   "source": [
    "### Pseudocode\n",
    "<center>\n",
    "<img src=../_static/images/pseudocode_off_policy.svg></img>\n",
    "</center>\n",
    "\n",
    "For the on-policy trainer, the main difference is that we clear the buffer after Line 10."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Hcp_o0CCFz12"
   },
   "source": [
    "### Training without trainer\n",
    "As we have learned the usages of the Collector and the Policy, it's possible that we write our own training logic.\n",
    "\n",
    "First, let us create the instances of Environment, ReplayBuffer, Policy and Collector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "id": "do-xZ-8B7nVH",
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "hide-cell",
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "import gymnasium as gym\n",
    "import torch\n",
    "\n",
    "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n",
    "from tianshou.env import DummyVectorEnv\n",
    "from tianshou.policy import PGPolicy\n",
    "from tianshou.trainer import OnpolicyTrainer\n",
    "from tianshou.utils.net.common import Net\n",
    "from tianshou.utils.net.discrete import Actor\n",
    "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_env_num = 4\n",
    "# Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n",
    "buffer_size = 2000\n",
    "\n",
    "\n",
    "# Create the environments, used for training and evaluation\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n",
    "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n",
    "\n",
    "# Create the Policy instance\n",
    "assert env.observation_space.shape is not None\n",
    "net = Net(\n",
    "    env.observation_space.shape,\n",
    "    hidden_sizes=[\n",
    "        16,\n",
    "    ],\n",
    ")\n",
    "\n",
    "assert isinstance(env.action_space, gym.spaces.Discrete)\n",
    "actor = Actor(net, env.action_space.n)\n",
    "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n",
    "\n",
    "# We choose to use REINFORCE algorithm, also known as Policy Gradient\n",
    "policy: PGPolicy = PGPolicy(\n",
    "    actor=actor,\n",
    "    optim=optim,\n",
    "    dist_fn=torch.distributions.Categorical,\n",
    "    action_space=env.action_space,\n",
    "    action_scaling=False,\n",
    ")\n",
    "\n",
    "# Create the replay buffer and the collector\n",
    "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
    "test_collector = Collector[CollectStats](policy, test_envs)\n",
    "train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wiEGiBgQIiFM"
   },
   "source": [
    "Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JMUNPN5SI_kd",
    "outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d"
   },
   "outputs": [],
   "source": [
    "train_collector.reset()\n",
    "train_envs.reset()\n",
    "test_collector.reset()\n",
    "test_envs.reset()\n",
    "replayBuffer.reset()\n",
    "\n",
    "n_episode = 10\n",
    "for _i in range(n_episode):\n",
    "    # for test collector, we set the wrapped torch module to evaluation mode\n",
    "    # by default, the policy object itself is not within the training step\n",
    "    with torch_train_mode(policy, enabled=False):\n",
    "        evaluation_result = test_collector.collect(n_episode=n_episode)\n",
    "    print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n",
    "    # for collecting data for training, the policy object should be within the training step\n",
    "    # (affecting e.g. whether the policy is stochastic or deterministic)\n",
    "    with policy_within_training_step(policy):\n",
    "        train_collector.collect(n_step=2000)\n",
    "        # 0 means taking all data stored in train_collector.buffer\n",
    "        # for updating the policy, the wrapped torch module should be in training mode\n",
    "        with torch_train_mode(policy):\n",
    "            policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
    "    train_collector.reset_buffer(keep_statistics=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QXBHIBckMs_2"
   },
   "source": [
    "The evaluation reward doesn't seem to improve. That is simply because we haven't trained it for enough time. Plus, the network size is too small and REINFORCE algorithm is actually not very stable. Don't worry, we will solve this problem in the end. Still we get some idea on how to start a training loop."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "p-7U_cwgF5Ej"
   },
   "source": [
    "### Training with trainer\n",
    "The trainer does almost the same thing. The only difference is that it has considered many details and is more modular."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "vcvw9J8RNtFE",
    "outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5",
    "tags": [
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "train_collector.reset()\n",
    "train_envs.reset()\n",
    "test_collector.reset()\n",
    "test_envs.reset()\n",
    "replayBuffer.reset()\n",
    "\n",
    "result = OnpolicyTrainer(\n",
    "    policy=policy,\n",
    "    train_collector=train_collector,\n",
    "    test_collector=test_collector,\n",
    "    max_epoch=10,\n",
    "    step_per_epoch=1,\n",
    "    repeat_per_collect=1,\n",
    "    episode_per_test=10,\n",
    "    step_per_collect=2000,\n",
    "    batch_size=512,\n",
    ").run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "result.pprint_asdict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_j3aUJZQ7nml"
   },
   "source": [
    "## Further Reading\n",
    "### Logger usages\n",
    "Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. Check the [documentation](https://tianshou.org/en/master/03_api/utils/logger/base.html#tianshou.utils.logger.base.BaseLogger) for details.\n",
    "\n",
    "### Learn more about the APIs of Trainers\n",
    "[documentation](https://tianshou.org/en/master/03_api/trainer/index.html)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "S3-tJZy35Ck_",
    "XfsuU2AAE52C",
    "p-7U_cwgF5Ej",
    "_j3aUJZQ7nml"
   ],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
